{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from datasets import PartDataset\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Earphone': 0, 'Motorbike': 1, 'Rocket': 2, 'Car': 3, 'Laptop': 4, 'Cap': 5, 'Skateboard': 6, 'Lamp': 10, 'Guitar': 8, 'Bag': 9, 'Mug': 7, 'Table': 11, 'Airplane': 12, 'Pistol': 13, 'Chair': 14, 'Knife': 15}\n"
     ]
    }
   ],
   "source": [
    "d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(d.classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Depth of the kd-tree: every split halves a node, so a 2048-point cloud\n",
    "# is reduced to single-point leaves after log2(2048) = 11 levels.\n",
    "# Round before converting to int: truncating a floating-point log ratio\n",
    "# (e.g. 10.999999...) would silently drop a level.\n",
    "levels = int(np.round(np.log2(2048)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Zero-initialised int64 vector with one entry per tree level.\n",
    "cutdim = torch.zeros(levels).type(torch.LongTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def split_ps(point_set):\n",
    "    \"\"\"Split a point set into two equal halves at the median of its widest axis.\n",
    "\n",
    "    Args:\n",
    "        point_set: 2-D tensor with one point per row and at least two rows.\n",
    "\n",
    "    Returns:\n",
    "        (left_ps, right_ps, dim). left_ps holds the N // 2 rows on the upper\n",
    "        side of the cut, right_ps the N // 2 rows on the lower side, and dim\n",
    "        is the index (a Python int) of the axis that was cut.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if point_set has fewer than two rows.\n",
    "\n",
    "    Rows lying exactly on the cut are shared out between the two halves.\n",
    "    For an even number of rows every input row ends up in exactly one half.\n",
    "    \"\"\"\n",
    "    total = point_set.size(0)\n",
    "    if total < 2:\n",
    "        raise ValueError('split_ps needs at least 2 points, got %d' % total)\n",
    "    # Floor division keeps this an int on Python 3 as well as on Python 2.\n",
    "    num_points = total // 2\n",
    "\n",
    "    # Per-axis extent, flattened so the code does not depend on whether the\n",
    "    # reductions keep the reduced dimension.\n",
    "    extent = (point_set.max(dim=0)[0] - point_set.min(dim=0)[0]).contiguous().view(-1)\n",
    "    spans = extent.tolist()\n",
    "    dim = spans.index(max(spans))\n",
    "\n",
    "    column = point_set[:, dim]\n",
    "    # Lower median, read off a sort so that the value is always an element\n",
    "    # of the column and at most N // 2 rows lie strictly on either side.\n",
    "    cut = torch.sort(column)[0][(total - 1) // 2]\n",
    "\n",
    "    def flat_indices(mask):\n",
    "        # Positions where mask is set as a 1-D index tensor, None if there are none.\n",
    "        idx = torch.nonzero(mask)\n",
    "        if torch.numel(idx) == 0:\n",
    "            return None\n",
    "        return idx.contiguous().view(-1)\n",
    "\n",
    "    upper = flat_indices(column > cut)\n",
    "    lower = flat_indices(column < cut)\n",
    "    ties = flat_indices(column == cut)  # never None: cut is one of the values\n",
    "\n",
    "    n_upper = 0 if upper is None else upper.size(0)\n",
    "    n_lower = 0 if lower is None else lower.size(0)\n",
    "    need_left = max(num_points - n_upper, 0)\n",
    "    need_right = max(num_points - n_lower, 0)\n",
    "\n",
    "    # Top both halves up with *distinct* tied rows: the first need_right ties\n",
    "    # go to the lower half and the following need_left ties to the upper half.\n",
    "    # n_upper + n_lower + len(ties) == total guarantees there are enough ties.\n",
    "    left_parts = [] if upper is None else [upper]\n",
    "    if need_left > 0:\n",
    "        left_parts.append(ties[need_right:need_right + need_left])\n",
    "    right_parts = [] if lower is None else [lower]\n",
    "    if need_right > 0:\n",
    "        right_parts.append(ties[0:need_right])\n",
    "    left_idx = torch.cat(left_parts, 0)\n",
    "    right_idx = torch.cat(right_parts, 0)\n",
    "\n",
    "    left_ps = torch.index_select(point_set, 0, left_idx)\n",
    "    right_ps = torch.index_select(point_set, 0, right_idx)\n",
    "    return left_ps, right_ps, dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.62 s, sys: 14 ms, total: 3.64 s\n",
      "Wall time: 3.65 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "def build_kdtree(point_set, depth):\n",
    "    \"\"\"Halve point_set depth times with split_ps.\n",
    "\n",
    "    Returns (leaves, cut_dims). leaves stacks the nodes of the last level in\n",
    "    left-to-right order. cut_dims[l] is an int64 tensor holding, for every\n",
    "    node created at level l + 1, the axis its parent was cut along, so each\n",
    "    parent contributes the same value twice.\n",
    "    \"\"\"\n",
    "    nodes = [point_set]\n",
    "    cut_dims = []\n",
    "    for _ in range(depth):\n",
    "        children, axes = [], []\n",
    "        for node in nodes:\n",
    "            left_ps, right_ps, dim = split_ps(node)\n",
    "            children += [left_ps, right_ps]\n",
    "            axes += [dim, dim]\n",
    "        cut_dims.append(torch.from_numpy(np.array(axes).astype(np.int64)))\n",
    "        nodes = children\n",
    "    return torch.stack(nodes), cut_dims\n",
    "\n",
    "points_batch = []\n",
    "cutdim_batch = []\n",
    "# Only the first 15 samples of the dataset are processed here.\n",
    "for sample_idx in range(15):\n",
    "    point_set, class_label = d[sample_idx]\n",
    "    points, sample_cutdims = build_kdtree(point_set, levels)\n",
    "    # Drop the singleton leaf axis, add a batch axis and move the coordinates\n",
    "    # to the channel position expected by Conv1d.\n",
    "    points_batch.append(torch.unsqueeze(torch.squeeze(points), 0).transpose(2,1))\n",
    "    cutdim_batch.append(sample_cutdims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Concatenate the per-sample point tensors along the batch axis.\n",
    "points_v = Variable(torch.cat(points_batch, 0))\n",
    "# Regroup the split axes level by level: entry l stacks level l of every sample.\n",
    "cutdim_processed = [\n",
    "    torch.stack([sample_cutdims[level] for sample_cutdims in cutdim_batch], 0)\n",
    "    for level in range(len(cutdim_batch[0]))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class KDNet(nn.Module):\n",
    "    \"\"\"Kd-tree network classifier for clouds of 2048 points in leaf order.\n",
    "\n",
    "    Each of the 11 levels applies a 1x1 convolution producing three feature\n",
    "    groups per node, keeps the group matching the axis recorded for that\n",
    "    node, and max-pools sibling pairs, which halves the number of nodes.\n",
    "    A linear layer maps the remaining 1024 features to k class scores.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, k = 16):\n",
    "        super(KDNet, self).__init__()\n",
    "        # Output channels are 3 * featdim: one feature group per split axis.\n",
    "        self.conv1 = nn.Conv1d(3,8 * 3,1,1)\n",
    "        self.conv2 = nn.Conv1d(8,32 * 3,1,1)\n",
    "        self.conv3 = nn.Conv1d(32,64 * 3,1,1)\n",
    "        self.conv4 = nn.Conv1d(64,64 * 3,1,1)\n",
    "        self.conv5 = nn.Conv1d(64,64 * 3,1,1)\n",
    "        self.conv6 = nn.Conv1d(64,128 * 3,1,1)\n",
    "        self.conv7 = nn.Conv1d(128,256 * 3,1,1)\n",
    "        self.conv8 = nn.Conv1d(256,512 * 3,1,1)\n",
    "        self.conv9 = nn.Conv1d(512,512 * 3,1,1)\n",
    "        self.conv10 = nn.Conv1d(512,512 * 3,1,1)\n",
    "        self.conv11 = nn.Conv1d(512,1024 * 3,1,1)\n",
    "        self.fc = nn.Linear(1024, k)\n",
    "\n",
    "    @staticmethod\n",
    "    def _kdconv(x, dim, featdim, sel, conv):\n",
    "        \"\"\"One tree level: convolve, select by split axis, pool sibling pairs.\n",
    "\n",
    "        x: (batch, in_channels, dim) features, one column per node.\n",
    "        dim: number of nodes at this level (even).\n",
    "        featdim: features kept per node; conv must output 3 * featdim channels.\n",
    "        sel: (batch, dim) LongTensor with the split axis (0, 1 or 2) of each node.\n",
    "        conv: the 1x1 convolution of this level.\n",
    "        Returns a (batch, featdim, dim // 2) tensor.\n",
    "        \"\"\"\n",
    "        batchsize = x.size(0)\n",
    "        x = F.relu(conv(x))\n",
    "        # Split the channels into (featdim, 3); the size-3 axis is the split axis.\n",
    "        x = x.view(batchsize, featdim, 3, dim)\n",
    "        # Keep, for every node p, the group sel[b, p] at that same position p.\n",
    "        # Gathering along the size-3 axis avoids hand-computed flat offsets,\n",
    "        # which for this (3, dim) layout would have to be sel * dim + p.\n",
    "        index = sel.contiguous().view(batchsize, 1, 1, dim)\n",
    "        index = Variable(index.expand(batchsize, featdim, 1, dim).contiguous())\n",
    "        if x.is_cuda:\n",
    "            index = index.cuda()\n",
    "        x = torch.gather(x, 2, index)\n",
    "        # Pair up siblings and keep the element-wise maximum of each pair.\n",
    "        x = x.view(batchsize, featdim, dim // 2, 2)\n",
    "        x = torch.max(x, 3)[0]\n",
    "        # Explicit view: valid whether or not max keeps the reduced axis.\n",
    "        return x.contiguous().view(batchsize, featdim, dim // 2)\n",
    "\n",
    "    def forward(self, x, c):\n",
    "        \"\"\"Classify a batch of kd-tree ordered point clouds.\n",
    "\n",
    "        x: (batch, 3, 2048) point coordinates.\n",
    "        c: sequence of at least 11 LongTensors of split axes; c[-1] is\n",
    "           (batch, 2048), c[-2] is (batch, 1024), ..., c[-11] is (batch, 2).\n",
    "        Returns the (batch, k) output of log_softmax over the class scores.\n",
    "        \"\"\"\n",
    "        feat_dims = (8, 32, 64, 64, 64, 128, 256, 512, 512, 512, 1024)\n",
    "        dim = 2048\n",
    "        for depth, featdim in enumerate(feat_dims, 1):\n",
    "            conv = getattr(self, 'conv%d' % depth)\n",
    "            x = self._kdconv(x, dim, featdim, c[-depth], conv)\n",
    "            dim //= 2\n",
    "        x = x.view(-1, 1024)\n",
    "        # 2-D input, so the implicit softmax axis is the class axis.\n",
    "        out = F.log_softmax(self.fc(x))\n",
    "        return out\n",
    "\n",
    "net = KDNet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "x = net(points_v, cutdim_processed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([15, 3, 2048])"
      ]
     },
     "execution_count": 138,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "points_v.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "torch.sum(x).backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       " 0\n",
       "[torch.LongTensor of size 1]"
      ]
     },
     "execution_count": 140,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable containing:\n",
       "\n",
       "Columns 0 to 9 \n",
       "-2.7780 -2.7592 -2.7585 -2.7613 -2.7265 -2.7407 -2.7838 -2.7462 -2.7619 -2.8141\n",
       "-2.7769 -2.7602 -2.7597 -2.7613 -2.7270 -2.7394 -2.7838 -2.7458 -2.7614 -2.8118\n",
       "-2.7770 -2.7599 -2.7598 -2.7613 -2.7266 -2.7392 -2.7840 -2.7456 -2.7613 -2.8123\n",
       "-2.7778 -2.7603 -2.7586 -2.7621 -2.7264 -2.7403 -2.7828 -2.7451 -2.7613 -2.8128\n",
       "-2.7629 -2.7719 -2.7881 -2.7567 -2.7186 -2.7342 -2.7802 -2.7661 -2.7773 -2.8205\n",
       "-2.7630 -2.7726 -2.7879 -2.7574 -2.7176 -2.7360 -2.7803 -2.7648 -2.7785 -2.8198\n",
       "-2.7772 -2.7599 -2.7598 -2.7613 -2.7266 -2.7392 -2.7840 -2.7455 -2.7613 -2.8123\n",
       "-2.7772 -2.7596 -2.7599 -2.7602 -2.7266 -2.7393 -2.7854 -2.7465 -2.7623 -2.8118\n",
       "-2.7775 -2.7594 -2.7576 -2.7616 -2.7268 -2.7406 -2.7856 -2.7473 -2.7626 -2.8133\n",
       "-2.7644 -2.7734 -2.7887 -2.7560 -2.7164 -2.7352 -2.7774 -2.7660 -2.7785 -2.8201\n",
       "-2.7775 -2.7595 -2.7582 -2.7616 -2.7261 -2.7408 -2.7848 -2.7466 -2.7621 -2.8136\n",
       "-2.7766 -2.7607 -2.7601 -2.7613 -2.7271 -2.7392 -2.7846 -2.7454 -2.7610 -2.8119\n",
       "-2.7637 -2.7743 -2.7889 -2.7554 -2.7172 -2.7335 -2.7791 -2.7659 -2.7775 -2.8205\n",
       "-2.7876 -2.7618 -2.7577 -2.7607 -2.7298 -2.7405 -2.7859 -2.7429 -2.7545 -2.8090\n",
       "-2.7856 -2.7605 -2.7601 -2.7624 -2.7243 -2.7442 -2.7859 -2.7416 -2.7523 -2.8085\n",
       "\n",
       "Columns 10 to 15 \n",
       "-2.7676 -2.7991 -2.8075 -2.8007 -2.7917 -2.7695\n",
       "-2.7686 -2.7997 -2.8059 -2.8014 -2.7946 -2.7686\n",
       "-2.7688 -2.8000 -2.8058 -2.8013 -2.7945 -2.7688\n",
       "-2.7690 -2.7994 -2.8054 -2.8006 -2.7955 -2.7687\n",
       "-2.7791 -2.8000 -2.7665 -2.7885 -2.7936 -2.7617\n",
       "-2.7777 -2.8001 -2.7660 -2.7898 -2.7935 -2.7611\n",
       "-2.7688 -2.7999 -2.8058 -2.8013 -2.7945 -2.7687\n",
       "-2.7690 -2.7992 -2.8057 -2.8001 -2.7948 -2.7683\n",
       "-2.7674 -2.7976 -2.8066 -2.7995 -2.7937 -2.7690\n",
       "-2.7787 -2.8004 -2.7655 -2.7888 -2.7947 -2.7616\n",
       "-2.7678 -2.7975 -2.8076 -2.8003 -2.7928 -2.7693\n",
       "-2.7689 -2.7995 -2.8057 -2.8014 -2.7944 -2.7682\n",
       "-2.7774 -2.7987 -2.7672 -2.7900 -2.7933 -2.7635\n",
       "-2.7682 -2.8052 -2.8053 -2.7949 -2.7920 -2.7702\n",
       "-2.7672 -2.8032 -2.8088 -2.7979 -2.7931 -2.7708\n",
       "[torch.FloatTensor of size 15x16]"
      ]
     },
     "execution_count": 135,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
